import json
import random
import re
import sys
import os
import time
import openai

from tqdm import tqdm
from openai import AzureOpenAI
from datasets import load_dataset
import numpy as np

# Azure OpenAI Configuration.
# The key and endpoint are taken from the environment when set (the same variable
# names the AzureOpenAI SDK itself reads), so secrets need not be committed to
# source; the placeholder strings remain as fallbacks.
api_version = "2024-02-15-preview"
config_dict = {
    "api_key": os.getenv("AZURE_OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    "api_version": api_version,
    "azure_endpoint": os.getenv(
        "AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"
    ),
}

# -----------------------------------------------------------------------------
# Helper to format ad_details into text block
# -----------------------------------------------------------------------------


def format_ad_details(ad_dict: dict) -> str:
    """Render one ad's metadata as the plain-text block used in prompts and embeddings.

    Header fields (Duration, Brand, Title, Pace, Orientation) are emitted one per
    line in that order, each only when present and truthy. Each entry of
    ``ad_dict["Scenes"]`` then becomes a numbered ``Scene N:`` line carrying its
    Description, Emotions and Tags (missing values render as empty strings).

    Args:
        ad_dict: Parsed ``ad_details`` mapping for one record.

    Returns:
        The lines joined with newlines; an empty string when nothing is present.
    """
    lines = []
    if dur := ad_dict.get("Duration"):
        lines.append(f"Video Duration: {dur}")
    if brand := ad_dict.get("Brand"):
        lines.append(f"Brand: {brand}")
    if title := ad_dict.get("Title"):
        lines.append(f"Title: \"{title}\"")
    if pace := ad_dict.get("Pace"):
        lines.append(f"Pace: {pace}")
    if orient := ad_dict.get("Orientation"):
        lines.append(f"Orientation: {orient}")

    # ``or []`` also covers an explicit null ("Scenes": None), which
    # ``.get("Scenes", [])`` would pass through and then fail to iterate.
    for idx, scene in enumerate(ad_dict.get("Scenes") or [], 1):
        desc = scene.get("Description", "")
        emo = scene.get("Emotions", "")
        tags = scene.get("Tags", "")
        lines.append(f"Scene {idx}: {desc}. Emotions: {emo}. Tags: {tags}.")

    return "\n".join(lines)


# -----------------------------------------------------------------------------
# Prompt template
# -----------------------------------------------------------------------------


# System prompt sent with every request. Paragraph one frames the task and the
# 0–100 scale; paragraph two fixes the reply format (a reason line, then the
# score right after 'Answer:').
baseline_system_prompt = "\n\n".join(
    [
        " ".join(
            [
                "You are an expert evaluator of video advertisement memorability.",
                "Given structured details about a video advertisement, your job is to "
                "predict its long-term memorability score for a general audience on a "
                "0–100 scale (where 0 means instantly forgotten and 100 means "
                "unforgettable even weeks later).",
                "Think about narrative strength, emotional resonance, uniqueness, "
                "brand fit, pacing, and any visual or auditory hooks described.",
            ]
        ),
        "\n".join(
            [
                "Return your answer in this exact two-line format:",
                "Reason: <brief justification>",
                "Answer: <integer 0-100>",
                "Include **only one** numeric value (the score) and place it "
                "directly after 'Answer:'.",
            ]
        ),
    ]
)



# -----------------------------------------------------------------------------
# GPT helper
# -----------------------------------------------------------------------------


# Lazily created, shared client; see _get_azure_client.
_azure_client = None


def _get_azure_client():
    """Return one shared AzureOpenAI client built from ``config_dict``, creating it on first use.

    Each AzureOpenAI instance owns its own HTTP client, so building one per request
    discards connections that could be reused across the many calls this script makes.
    """
    global _azure_client
    if _azure_client is None:
        _azure_client = AzureOpenAI(
            api_key=config_dict["api_key"],
            api_version=config_dict["api_version"],
            azure_endpoint=config_dict["azure_endpoint"],
        )
    return _azure_client


def verbalize(prompt: str, sys_prompt: str) -> str:
    """Send one system + user chat turn to the ``gpt-4o`` deployment and return the reply.

    Sampling uses ``temperature=0.7`` and at most 300 completion tokens.

    Args:
        prompt: The user message.
        sys_prompt: The system message.

    Returns:
        The reply text with surrounding whitespace stripped, or ``""`` when the
        service returned no content (the SDK types ``message.content`` as optional).
    """
    resp = _get_azure_client().chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ],
        max_tokens=300,
        temperature=0.7,
    )

    content = resp.choices[0].message.content
    return (content or "").strip()


# -----------------------------------------------------------------------------
# Load LAMBDA dataset
# -----------------------------------------------------------------------------


# Optional first CLI argument: the Hugging Face datasets cache directory. When it
# is omitted (or empty), ``cache_dir=None`` lets ``datasets`` use its default
# cache location instead of a placeholder path that does not exist.
cache_dir = sys.argv[1] if len(sys.argv) > 1 else None

dataset = load_dataset(
    "behavior-in-the-wild/LAMBDA",
    split="test",
    cache_dir=cache_dir or None,
)

# Materialized once so records can be indexed by position (few-shot examples).
records = list(dataset)

# -----------------------------------------------------------------------------
# Pre-compute embeddings for similarity search
# -----------------------------------------------------------------------------
# Embedding model for the nearest-neighbour example search; the module-level
# openai key is taken from the OPENAI_API_KEY environment variable (None if unset).
EMBED_MODEL = "text-embedding-3-small"
openai.api_key = os.environ.get("OPENAI_API_KEY")

def _get_embeddings(texts, batch_size=96):
    """Embed ``texts`` with ``EMBED_MODEL``, sending ``batch_size`` strings per request.

    Args:
        texts: Sequence of strings to embed.
        batch_size: Maximum number of strings per API request.

    Returns:
        One float32 numpy vector per text, in input order. The result is
        all-or-nothing: if any request fails, the error is reported on stderr and
        a list of ``None`` (one per text) is returned so callers can fall back to
        sampling examples without similarity.
    """
    vecs = []
    for start in range(0, len(texts), batch_size):
        # The embeddings endpoint rejects empty strings, and one bad input fails
        # the whole request; send a single space for texts that formatted to nothing.
        chunk = [text if text else " " for text in texts[start:start + batch_size]]
        try:
            # openai>=1.0 module-level API (this file also uses the v1 AzureOpenAI
            # client); there is no ``openai.Embeddings`` attribute.
            resp = openai.embeddings.create(model=EMBED_MODEL, input=chunk)
            # Each item carries its input position; sort rather than trust order.
            ordered = sorted(resp.data, key=lambda item: item.index)
            vecs.extend(np.array(item.embedding, dtype=np.float32) for item in ordered)
        except Exception as exc:  # best-effort: any failure -> caller's fallback
            print(
                f"Embedding request failed ({type(exc).__name__}: {exc}); "
                "returning no embeddings.",
                file=sys.stderr,
            )
            return [None] * len(texts)
        time.sleep(0.1)  # light throttle between requests
    return vecs

def _cos(u,v):
    return float(np.dot(u,v)/(np.linalg.norm(u)*np.linalg.norm(v)+1e-8))

def _ad_details_as_dict(raw):
    """Return a record's ``ad_details`` value as a dict.

    JSON strings are decoded. Malformed JSON, or any value that is not a dict
    after decoding, yields ``{}`` so a single bad record cannot abort the run
    before inference starts.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return {}
    return raw if isinstance(raw, dict) else {}


# One formatted text block per record, in the same order as ``records``; these
# are embedded for the nearest-neighbour example search.
_all_blocks = [format_ad_details(_ad_details_as_dict(r["ad_details"])) for r in records]
try:
    all_embeddings = _get_embeddings(_all_blocks)
except Exception as e:
    # Best-effort: with all-None embeddings the examples are sampled at random.
    print(f"Embedding generation failed, using random sampling: {e}")
    all_embeddings = [None] * len(records)

# -----------------------------------------------------------------------------
# Inference loop (10 runs per sample)
# -----------------------------------------------------------------------------


# Single output path so periodic checkpoints and the final save land in the same file.
RESULTS_PATH = "gpt_lambda_nopersona_ten_results.json"
RUNS_PER_SAMPLE = 10
CHECKPOINT_EVERY = 20

# The system prompt asks for the score right after "Answer:" (possibly markdown-bolded).
# Preferring it avoids picking up other numbers, e.g. "Answer: 72/100" -> 72, not 100.
_ANSWER_RE = re.compile(r"Answer\**\s*:\**\s*(\d+(?:\.\d+)?)", re.IGNORECASE)
_NUMBER_RE = re.compile(r"\b\d+(?:\.\d+)?\b")


def _load_ad_details(raw):
    """Return an ``ad_details`` value as a dict, decoding JSON strings; ``{}`` if malformed."""
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return {}
    return raw if isinstance(raw, dict) else {}


def _extract_prediction(text):
    """Parse the predicted score from a model reply; ``None`` if it contains no number.

    Uses the last ``Answer: <n>`` in the reply; if the model ignored the format,
    falls back to the last number anywhere in the reply.
    """
    answers = _ANSWER_RE.findall(text)
    if answers:
        return float(answers[-1])
    numbers = _NUMBER_RE.findall(text)
    return float(numbers[-1]) if numbers else None


def _select_example_indices(target_idx, use_embeddings, k=5):
    """Pick up to ``k`` few-shot example indices for ``records[target_idx]``, never itself.

    With embeddings: the ``k`` most cosine-similar records (ties go to the higher
    index, since ``(similarity, index)`` tuples are sorted descending). Without:
    a uniform random sample. ``k`` matches the "five examples" in the prompt text.
    """
    candidates = [j for j in range(len(records)) if j != target_idx]
    if not use_embeddings:
        return random.sample(candidates, k=min(k, len(candidates)))
    target_vec = all_embeddings[target_idx]
    sims = [(_cos(target_vec, all_embeddings[j]), j) for j in candidates]
    sims.sort(reverse=True)
    return [j for _, j in sims[:k]]


def _save_results(results):
    """Write ``results`` to ``RESULTS_PATH`` as indented JSON (overwriting it)."""
    with open(RESULTS_PATH, "w") as f:
        json.dump(results, f, indent=2)


# Use similarity search only if every record has an embedding, so a None vector
# can never reach _cos; otherwise fall back to random examples.
_use_embeddings = all(vec is not None for vec in all_embeddings)

response_dict = []

for idx, record in enumerate(tqdm(records, desc="Nopersona-10x")):
    target_text = format_ad_details(_load_ad_details(record["ad_details"]))

    # Few-shot examples (most similar by embedding, or random), shown with
    # their ground-truth scores scaled to 0-100.
    example_lines = []
    for si in _select_example_indices(idx, _use_embeddings):
        ex = records[si]
        desc = format_ad_details(_load_ad_details(ex["ad_details"]))
        score = int(round(float(ex["recall_score"]) * 100))
        example_lines.append(f"{desc}\nScore: {score}\n")

    examples_block = "\n---\n".join(example_lines)

    prompt = (
        f"Below are five example video ads with their memorability scores (0-100). "
        f"Assess the sixth ad and predict its score.\n\n{examples_block}\n---\n{target_text}\n\nQuestion: What is the long-term memorability score of this video?"
    )

    # Query several times; the mean of the parsed scores is the final prediction.
    all_responses = []
    all_predictions = []
    for _ in range(RUNS_PER_SAMPLE):
        resp = verbalize(prompt, baseline_system_prompt)
        all_responses.append(resp)
        pred = _extract_prediction(resp)
        if pred is not None:
            all_predictions.append(pred)

    mean_prediction = float(np.mean(all_predictions)) if all_predictions else None

    response_dict.append(
        {
            "video_id": int(record["video_id"]),
            "youtube_id": record["youtube_id"],
            "ground_truth": float(record["recall_score"]) * 100,
            "prompt": prompt,
            "all_responses": all_responses,
            "all_predictions": all_predictions,
            "mean_prediction": mean_prediction,
        }
    )

    if idx % CHECKPOINT_EVERY == 0:
        _save_results(response_dict)

# Final save
_save_results(response_dict)